{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0.0-beta1\n",
      "sys.version_info(major=3, minor=7, micro=3, releaselevel='final', serial=0)\n",
      "matplotlib 3.0.3\n",
      "numpy 1.16.2\n",
      "pandas 0.24.2\n",
      "sklearn 0.20.3\n",
      "tensorflow 2.0.0-beta1\n",
      "tensorflow.python.keras.api._v2.keras 2.2.4-tf\n"
     ]
    }
   ],
   "source": [
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "import tensorflow as tf\n",
    "\n",
    "from tensorflow import keras\n",
    "\n",
    "print(tf.__version__)\n",
    "print(sys.version_info)\n",
    "for module in mpl, np, pd, sklearn, tf, keras:\n",
    "    print(module.__name__, module.__version__)\n",
    "\n",
    "    \n",
    "'''\n",
    "多种转化关系：从其他各种类型转化为 tflite，并写入文件中\n",
    "1. 具体函数转化 tf.lite.TFLiteConverter.from_concrete_functions([keras_concrete_func])\n",
    "2. savedmodel转化 tf.lite.TFLiteConverter.from_saved_model('./keras_saved_graph/')\n",
    "3. keras model模型转化为 tflite tf.lite.TFLiteConverter.from_keras_model(loaded_keras_model)\n",
    "\n",
    "大致步骤是：\n",
    "1. 产生一个 converter\n",
    "2. converter对象调用 convert()方法，产生tflite\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "W0714 19:35:08.749957 140736297124800 deprecation.py:323] From /Users/zhangyx/workspace/environments/tf2_py3/lib/python3.7/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=497, shape=(1, 10), dtype=float32, numpy=\n",
       "array([[0.20111723, 0.09845883, 0.02821723, 0.0199804 , 0.03568273,\n",
       "        0.07881083, 0.10847212, 0.02893549, 0.34842488, 0.05190022]],\n",
       "      dtype=float32)>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "## keras model模型转化为 tflite\n",
    "## 先载入model\n",
    "loaded_keras_model = keras.models.load_model(\n",
    "    './graph_def_and_weights/fashion_mnist_model.h5')\n",
    "\n",
    "## 测试model是否正常\n",
    "## np.ones((shape))，产生满足shape的 np.array 数据全是1\n",
    "loaded_keras_model(np.ones((1, 28, 28)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "## 从keras model中生成 converter 转换器对象\n",
    "keras_to_tflite_converter = tf.lite.TFLiteConverter.from_keras_model(\n",
    "    loaded_keras_model)\n",
    "# 调用converter转换器对象的 convert()方法，获取tflite\n",
    "keras_tflite = keras_to_tflite_converter.convert()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "## 将上面产生的tflite model给保存到文件中\n",
    "if not os.path.exists('./tflite_models'):\n",
    "    os.mkdir('./tflite_models')\n",
    "## 写入文件\n",
    "with open('./tflite_models/keras_tflite', 'wb') as f:\n",
    "    f.write(keras_tflite)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "## 定义tf函数对象，就相当于前面的 普通函数转换为tf函数\n",
    "run_model = tf.function(lambda x : loaded_keras_model(x))\n",
    "## 从tf函数中调用get_concrete_function 得到具体函数\n",
    "keras_concrete_func = run_model.get_concrete_function(\n",
    "    tf.TensorSpec(loaded_keras_model.inputs[0].shape,\n",
    "                  loaded_keras_model.inputs[0].dtype))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=650, shape=(1, 10), dtype=float32, numpy=\n",
       "array([[0.20111723, 0.09845883, 0.02821723, 0.0199804 , 0.03568273,\n",
       "        0.07881083, 0.10847212, 0.02893549, 0.34842488, 0.05190022]],\n",
       "      dtype=float32)>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "## 测试具体函数\n",
    "keras_concrete_func(tf.constant(np.ones((1, 28, 28), dtype=np.float32)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "## 具体函数就能够被存储成 converter转换器\n",
    "concrete_func_to_tflite_converter = tf.lite.TFLiteConverter.from_concrete_functions([keras_concrete_func])\n",
    "## 转换器 调用convert方法获取 tflite\n",
    "concrete_func_tflite = concrete_func_to_tflite_converter.convert()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('./tflite_models/concrete_func_tflite', 'wb') as f:\n",
    "    f.write(concrete_func_tflite)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "## 把savedmodel 给转化为 converter\n",
    "saved_model_to_tflite_converter = tf.lite.TFLiteConverter.from_saved_model('./keras_saved_graph/')\n",
    "## 进而转化为 tflite\n",
    "saved_model_tflite = saved_model_to_tflite_converter.convert()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('./tflite_models/saved_model_tflite', 'wb') as f:\n",
    "    f.write(saved_model_tflite)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
